{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2025, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
#include "pybind_array_tools.h"
#include "solver.h"

namespace caspar {

inline void add_solver_pybinding(pybind11::module_ module) {
  py::class_<{{solver.struct_name}}>(module, "{{solver.struct_name}}",
                                     "Class for solving Factor Graphs.")
      .def(py::init<SolverParams,
           {% for thing in solver.size_contributors %}
           size_t{{ ", " if not loop.last else "" }}
           {% endfor %}>(),
           py::arg("params"),
           py::kw_only(),
           {% for thing in solver.size_contributors %}
           py::arg("{{num_arg_key(thing)}}") = 0{{ ", " if not loop.last else "" }}
           {% endfor %}
       )

      .def("set_params", &{{solver.struct_name}}::set_params)
      .def("solve", &{{ solver.struct_name }}::solve, py::call_guard<py::gil_scoped_release>(),
           py::arg("print_progress") = false)
      .def("finish_indices", &{{solver.struct_name}}::finish_indices)
      .def("get_allocation_size", &{{solver.struct_name}}::get_allocation_size)

      {% for nodetype in solver.node_types %}
      .def("set_{{nodetype.__name__}}_num", &{{solver.struct_name}}::set_{{nodetype.__name__}}_num)
      .def("set_{{nodetype.__name__}}_nodes_from_stacked_host",
           []({{ solver.struct_name }} &solver, pybind11::object stacked_data, size_t offset) {
             AssertHostMemory(stacked_data);
             solver.set_{{nodetype.__name__}}_nodes_from_stacked_host(
                 AsFloatPtr(stacked_data),
                 offset,
                 GetNumRows(stacked_data));
           },
           pybind11::arg("stacked_nodes"), pybind11::arg("offset") = 0
      )
      .def("set_{{nodetype.__name__}}_nodes_from_stacked_device",
           []({{ solver.struct_name }} &solver, pybind11::object stacked_data, size_t offset) {
             AssertDeviceMemory(stacked_data);
             solver.set_{{nodetype.__name__}}_nodes_from_stacked_device(
                 AsFloatPtr(stacked_data),
                 offset,
                 GetNumRows(stacked_data));
           },
           pybind11::arg("stacked_nodes"), pybind11::arg("offset") = 0
      )
      .def("get_{{nodetype.__name__}}_nodes_to_stacked_host",
           []({{ solver.struct_name }} &solver, pybind11::object stacked_data, size_t offset) {
             AssertHostMemory(stacked_data);
             solver.get_{{nodetype.__name__}}_nodes_to_stacked_host(
                 AsFloatPtr(stacked_data),
                 offset,
                 GetNumRows(stacked_data));
           },
           pybind11::arg("stacked_nodes"), pybind11::arg("offset") = 0
      )
      .def("get_{{nodetype.__name__}}_nodes_to_stacked_device",
           []({{ solver.struct_name }} &solver, pybind11::object stacked_data, size_t offset) {
             AssertDeviceMemory(stacked_data);
             solver.get_{{nodetype.__name__}}_nodes_to_stacked_device(
                 AsFloatPtr(stacked_data),
                 offset,
                 GetNumRows(stacked_data));
           },
           pybind11::arg("stacked_nodes"), pybind11::arg("offset") = 0
      )

      {% endfor %}
      {% for factor in solver.factors %}
      .def("set_{{factor.name}}_num", &{{solver.struct_name}}::set_{{factor.name}}_num)

      {% for argname, argtype in factor.node_arg_types.items() %}
      .def("set_{{factor.name}}_{{argname}}_indices_from_host",
           []({{ solver.struct_name }} &solver, pybind11::object indices) {
             AssertHostMemory(indices);
             solver.set_{{factor.name}}_{{argname}}_indices_from_host(
                 AsUintPtr(indices),
                 GetNumRows(indices));
           }
      )
      .def("set_{{factor.name}}_{{argname}}_indices_from_device",
           []({{ solver.struct_name }} &solver, pybind11::object indices) {
             AssertDeviceMemory(indices);
             solver.set_{{factor.name}}_{{argname}}_indices_from_device(
                 AsUintPtr(indices),
                 GetNumRows(indices));
           }
      )
      {% endfor %}
      {% for argname, argtype in factor.const_arg_types.items() %}
      .def("set_{{factor.name}}_{{argname}}_data_from_stacked_device",
           []({{ solver.struct_name }} &solver, pybind11::object stacked_data, size_t offset) {
             AssertDeviceMemory(stacked_data);
             solver.set_{{factor.name}}_{{argname}}_data_from_stacked_device(
                 AsFloatPtr(stacked_data),
                 offset,
                 GetNumRows(stacked_data));
           },
           pybind11::arg("stacked_data"), pybind11::arg("offset") = 0
      )
      .def("set_{{factor.name}}_{{argname}}_data_from_stacked_host",
           []({{ solver.struct_name }} &solver, pybind11::object stacked_data, size_t offset) {
             AssertHostMemory(stacked_data);
             solver.set_{{factor.name}}_{{argname}}_data_from_stacked_host(
                 AsFloatPtr(stacked_data),
                 offset,
                 GetNumRows(stacked_data));
           },
           pybind11::arg("stacked_data"), pybind11::arg("offset") = 0
      )
      {% endfor %}
      {% endfor %}
      ;
}

} // namespace caspar
